import argparse
import time
import traceback

import dgl

import networkx as nx
import numpy as np
import torch
from dataloader import (
    MultiBodyGraphCollator,
    MultiBodyTestDataset,
    MultiBodyTrainDataset,
    MultiBodyValidDataset,
)
from models import InteractionNet, MLP, PrepareLayer
from torch.utils.data import DataLoader
from utils import make_video


def train(
    optimizer, loss_fn, reg_fn, model, prep, dataloader, lambda_reg, device
):
    total_loss = 0
    model.train()
    for i, (graph_batch, data_batch, label_batch) in enumerate(dataloader):
        graph_batch = graph_batch.to(device)
        data_batch = data_batch.to(device)
        label_batch = label_batch.to(device)
        optimizer.zero_grad()
        node_feat, edge_feat = prep(graph_batch, data_batch)
        dummy_relation = torch.zeros(edge_feat.shape[0], 1).float().to(device)
        dummy_global = torch.zeros(node_feat.shape[0], 1).float().to(device)
        v_pred, out_e = model(
            graph_batch,
            node_feat[:, 3:5].float(),
            edge_feat.float(),
            dummy_global,
            dummy_relation,
        )
        loss = loss_fn(v_pred, label_batch)
        total_loss += float(loss)
        zero_target = torch.zeros_like(out_e)
        loss = loss + lambda_reg * reg_fn(out_e, zero_target)
        reg_loss = 0
        for param in model.parameters():
            reg_loss = reg_loss + lambda_reg * reg_fn(
                param, torch.zeros_like(param).float().to(device)
            )
        loss = loss + reg_loss
        loss.backward()
        optimizer.step()
    return total_loss / (i + 1)


# One step evaluation


def eval(loss_fn, model, prep, dataloader, device):
    total_loss = 0
    model.eval()
    for i, (graph_batch, data_batch, label_batch) in enumerate(dataloader):
        graph_batch = graph_batch.to(device)
        data_batch = data_batch.to(device)
        label_batch = label_batch.to(device)
        node_feat, edge_feat = prep(graph_batch, data_batch)
        dummy_relation = torch.zeros(edge_feat.shape[0], 1).float().to(device)
        dummy_global = torch.zeros(node_feat.shape[0], 1).float().to(device)
        v_pred, _ = model(
            graph_batch,
            node_feat[:, 3:5].float(),
            edge_feat.float(),
            dummy_global,
            dummy_relation,
        )
        loss = loss_fn(v_pred, label_batch)
        total_loss += float(loss)
    return total_loss / (i + 1)


# Rollout Evaluation based in initial state
# Need to integrate


def eval_rollout(model, prep, initial_frame, n_object, device):
    current_frame = initial_frame.to(device)
    base_graph = nx.complete_graph(n_object)
    graph = dgl.from_networkx(base_graph).to(device)
    pos_buffer = []
    model.eval()
    for step in range(100):
        node_feats, edge_feats = prep(graph, current_frame)
        dummy_relation = torch.zeros(edge_feats.shape[0], 1).float().to(device)
        dummy_global = torch.zeros(node_feats.shape[0], 1).float().to(device)
        v_pred, _ = model(
            graph,
            node_feats[:, 3:5].float(),
            edge_feats.float(),
            dummy_global,
            dummy_relation,
        )
        current_frame[:, [1, 2]] += v_pred * 0.001
        current_frame[:, 3:5] = v_pred
        pos_buffer.append(current_frame[:, [1, 2]].cpu().numpy())
    pos_buffer = np.vstack(pos_buffer).reshape(100, n_object, -1)
    make_video(pos_buffer, "video_model.mp4")


if __name__ == "__main__":
    argparser = argparse.ArgumentParser()
    argparser.add_argument(
        "--lr", type=float, default=0.001, help="learning rate"
    )
    argparser.add_argument(
        "--epochs", type=int, default=40000, help="Number of epochs in training"
    )
    argparser.add_argument(
        "--lambda_reg", type=float, default=0.001, help="regularization weight"
    )
    argparser.add_argument(
        "--gpu", type=int, default=-1, help="gpu device code, -1 means cpu"
    )
    argparser.add_argument(
        "--batch_size", type=int, default=100, help="size of each mini batch"
    )
    argparser.add_argument(
        "--num_workers",
        type=int,
        default=0,
        help="number of workers for dataloading",
    )
    argparser.add_argument(
        "--visualize",
        action="store_true",
        default=False,
        help="Whether enable trajectory rollout mode for visualization",
    )
    args = argparser.parse_args()

    # Select Device to be CPU or GPU
    if args.gpu != -1:
        device = torch.device("cuda:{}".format(args.gpu))
    else:
        device = torch.device("cpu")

    train_data = MultiBodyTrainDataset()
    valid_data = MultiBodyValidDataset()
    test_data = MultiBodyTestDataset()
    collator = MultiBodyGraphCollator(train_data.n_particles)

    train_dataloader = DataLoader(
        train_data,
        args.batch_size,
        True,
        collate_fn=collator,
        num_workers=args.num_workers,
    )
    valid_dataloader = DataLoader(
        valid_data,
        args.batch_size,
        True,
        collate_fn=collator,
        num_workers=args.num_workers,
    )
    test_full_dataloader = DataLoader(
        test_data,
        args.batch_size,
        True,
        collate_fn=collator,
        num_workers=args.num_workers,
    )

    node_feats = 5
    stat = {
        "median": torch.from_numpy(train_data.stat_median).to(device),
        "max": torch.from_numpy(train_data.stat_max).to(device),
        "min": torch.from_numpy(train_data.stat_min).to(device),
    }
    print(
        "Weight: ",
        train_data.stat_median[0],
        train_data.stat_max[0],
        train_data.stat_min[0],
    )
    print(
        "Position: ",
        train_data.stat_median[[1, 2]],
        train_data.stat_max[[1, 2]],
        train_data.stat_min[[1, 2]],
    )
    print(
        "Velocity: ",
        train_data.stat_median[[3, 4]],
        train_data.stat_max[[3, 4]],
        train_data.stat_min[[3, 4]],
    )

    prepare_layer = PrepareLayer(node_feats, stat).to(device)
    interaction_net = InteractionNet(node_feats, stat).to(device)
    print(interaction_net)
    optimizer = torch.optim.Adam(interaction_net.parameters(), lr=args.lr)
    state_dict = interaction_net.state_dict()

    loss_fn = torch.nn.MSELoss()
    reg_fn = torch.nn.MSELoss(reduction="sum")
    try:
        for e in range(args.epochs):
            last_t = time.time()
            loss = train(
                optimizer,
                loss_fn,
                reg_fn,
                interaction_net,
                prepare_layer,
                train_dataloader,
                args.lambda_reg,
                device,
            )
            print("Epoch time: ", time.time() - last_t)
            if e % 1 == 0:
                valid_loss = eval(
                    loss_fn,
                    interaction_net,
                    prepare_layer,
                    valid_dataloader,
                    device,
                )
                test_full_loss = eval(
                    loss_fn,
                    interaction_net,
                    prepare_layer,
                    test_full_dataloader,
                    device,
                )
                print(
                    "Epoch: {}.Loss: Valid: {} Full: {}".format(
                        e, valid_loss, test_full_loss
                    )
                )
    except:
        traceback.print_exc()
    finally:
        if args.visualize:
            eval_rollout(
                interaction_net,
                prepare_layer,
                test_data.first_frame,
                test_data.n_particles,
                device,
            )
            make_video(test_data.test_traj[:100, :, [1, 2]], "video_truth.mp4")
